# -*- coding: utf-8 -*-

import torch.nn as nn


class TokenEmbedding(nn.Embedding):
    """token嵌入"""

    def __init__(self, vocab_size, embed_size=512):
        """
        随机初始化 vocab_size x embed_size的向量矩阵
        :param vocab_size: 字典大小
        :param embed_size: 嵌入的维度
        """
        super().__init__(vocab_size, embed_size, padding_idx=0)
